package com.freetymekiyan.algorithms.level.medium;

import com.freetymekiyan.algorithms.utils.Utils;
import com.freetymekiyan.algorithms.utils.Utils.TreeNode;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.util.ArrayDeque;
import java.util.Deque;

/**
 * 230. Kth Smallest Element in a BST
 * <p>
 * Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.
 * <p>
 * Note:
 * You may assume k is always valid, 1 ≤ k ≤ BST's total elements.
 * <p>
 * Follow up:
 * What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How
 * would you optimize the kthSmallest routine?
 * <p>
 * Hint:
 * <p>
 * Try to utilize the property of a BST.
 * What if you could modify the BST node's structure?
 * The optimal runtime complexity is O(height of BST).
 * <p>
 * Tags: Binary Search Tree
 * Similar Problems: (M) Binary Tree Inorder Traversal
 * <p>
 * Answers:
 * If the BST is modified often, and I can modify the BST node's structure, I can store the left subtree's node count in
 * each node to know its ranking. Then use binary search to find the target node.
 */
public class KthSmallestElementInABst {

  /** Instance under test; created in {@link #setUp()} and released in {@link #tearDown()}. */
  private KthSmallestElementInABst k;

  /** Value found by the most recent recursive traversal, or -1 when nothing was found. */
  private int res;
  /** Number of nodes that still have to be visited in order before the answer is reached. */
  private int count;

  /**
   * Recursive solution with in-order traversal helper.
   * <p>
   * Uses the {@code res} and {@code count} fields as traversal state, which is reset on every call.
   *
   * @param root root of the tree to search, may be null
   * @param k    1-based rank of the wanted value
   * @return the value of the k-th node in in-order sequence, or -1 when the tree has no such node
   */
  public int kthSmallest(TreeNode root, int k) {
    count = k;
    res = -1; // Reset so that a previous call's answer is never returned by accident.
    traverse(root);
    return res;
  }

  /**
   * Visits the subtree rooted at {@code node} in order, decrementing {@code count} per visited node.
   *
   * @param node subtree root, may be null
   * @return true once the wanted node was found, which lets every caller stop traversing immediately
   */
  private boolean traverse(TreeNode node) {
    if (node == null) {
      return false;
    }
    if (traverse(node.left)) {
      return true;
    }
    count--;
    if (count == 0) {
      res = node.val;
      return true;
    }
    return traverse(node.right);
  }

  /**
   * Iterative solution with stack.
   * <p>
   * Pushes the left spine of the current subtree, then pops nodes in in-order sequence until the k-th one shows up.
   *
   * @param root root of the tree to search, may be null
   * @param k    1-based rank of the wanted value
   * @return the value of the k-th node in in-order sequence, or -1 when the tree has no such node
   */
  public int kthSmallestB(TreeNode root, int k) {
    Deque<TreeNode> stack = new ArrayDeque<>();
    int count = k;
    while (!stack.isEmpty() || root != null) {
      if (root != null) {
        stack.push(root);
        root = root.left;
      } else {
        root = stack.pop();
        count--;
        if (count == 0) {
          return root.val;
        }
        root = root.right;
      }
    }
    return -1;
  }

  /**
   * Binary search for left subtree node count.
   * For a BST, the number of nodes in the left subtree plus one is the node's ranking within its subtree.
   * <p>
   * The descent is repeated at every level. Since the counts are not cached in the nodes, each step recounts the
   * left subtree with {@link #countNodes(TreeNode)}.
   *
   * @param root root of the tree to search, may be null
   * @param k    1-based rank of the wanted value
   * @return the value of the k-th smallest node, or -1 when the tree has no such node
   */
  public int kthSmallestC(TreeNode root, int k) {
    TreeNode node = root;
    int rank = k;
    while (node != null) {
      int leftCount = countNodes(node.left);
      if (rank <= leftCount) {
        node = node.left;
      } else if (rank > leftCount + 1) {
        rank -= leftCount + 1; // Skip the left subtree and the current node.
        node = node.right;
      } else {
        return node.val;
      }
    }
    return -1;
  }

  /**
   * Count how many nodes in this subtree rooted from n.
   * If we can modify the data structure, we can save the count with each node.
   *
   * @param n subtree root, may be null
   * @return number of nodes in the subtree, 0 for a null subtree
   */
  private int countNodes(TreeNode n) {
    if (n == null) {
      return 0;
    }
    return 1 + countNodes(n.left) + countNodes(n.right);
  }

  @Before
  public void setUp() {
    k = new KthSmallestElementInABst();
  }

  @Test
  public void testExamples() {
    TreeNode root = Utils.buildBinaryTree(new Integer[]{1});
    Assert.assertEquals(1, k.kthSmallest(root, 1));
    Assert.assertEquals(1, k.kthSmallestB(root, 1));
    Assert.assertEquals(1, k.kthSmallestC(root, 1));
    root = Utils.buildBinaryTree(new Integer[]{2, 1});
    Assert.assertEquals(1, k.kthSmallest(root, 1));
    Assert.assertEquals(1, k.kthSmallestB(root, 1));
    Assert.assertEquals(1, k.kthSmallestC(root, 1));
    root = Utils.buildBinaryTree(new Integer[]{1, -1, 2, null, null, null, 3});
    Assert.assertEquals(3, k.kthSmallest(root, 4));
    Assert.assertEquals(3, k.kthSmallestB(root, 4));
    Assert.assertEquals(3, k.kthSmallestC(root, 4));
  }

  @After
  public void tearDown() {
    k = null;
  }

}
